{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "62787244",
   "metadata": {},
   "source": [
    "# Microbenchmarks on GPU\n",
    "This is a notebook for microbenchmarks running on GPU. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c3a15d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "from pyspark.conf import SparkConf\n",
    "from time import time\n",
    "import os\n",
    "# Change to your cluster ip:port and directories\n",
    "SPARK_MASTER_URL = os.getenv(\"SPARK_MASTER_URL\", \"spark:your-ip:port\")\n",
    "RAPIDS_JAR = os.getenv(\"RAPIDS_JAR\", \"/your-path/rapids-4-spark_2.12-25.04.0.jar\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b10a2ad1",
   "metadata": {},
   "source": [
    "Run the microbenchmark with retryTimes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c3536ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "def runMicroBenchmark(spark, appName, query, retryTimes):\n",
    "    count = 0\n",
    "    total_time = 0\n",
    "    # You can print the physical plan of each query\n",
    "    # spark.sql(query).explain()\n",
    "    while count < retryTimes:\n",
    "        start = time()\n",
    "        spark.sql(query).show(5)\n",
    "        end = time()\n",
    "        total_time += round(end - start, 2)\n",
    "        count = count + 1\n",
    "        print(\"Retry times : {}, \".format(count) + appName + \" microbenchmark takes {} seconds\".format(round(end - start, 2)))\n",
    "    print(appName + \" microbenchmark takes average {} seconds after {} retries\".format(round(total_time/retryTimes),retryTimes))\n",
    "    with open('result.txt', 'a') as file:\n",
    "        file.write(\"{},{},{}\\n\".format(appName, round(total_time/retryTimes), retryTimes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "975717da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# You need to update with your real hardware resource \n",
    "driverMem = os.getenv(\"DRIVER_MEM\", \"50g\")\n",
    "executorMem = os.getenv(\"EXECUTOR_MEM\", \"16g\")\n",
    "maxPartionBytes = os.getenv(\"MAX_PARTITION_BYTES\", \"1g\")\n",
    "pinnedPoolSize = os.getenv(\"PINNED_POOL_SIZE\", \"8g\")\n",
    "concurrentGpuTasks = os.getenv(\"CONCURRENT_GPU_TASKS\", \"4\")\n",
    "executorCores = int(os.getenv(\"EXECUTOR_CORES\", \"16\"))\n",
    "eventlogDir = \"file:\"+os.getenv(\"EVENTLOG_DIR\")\n",
    "gpuPerExecutor = 1/executorCores\n",
    "# Common spark settings\n",
    "conf = SparkConf()\n",
    "conf.setMaster(SPARK_MASTER_URL)\n",
    "conf.setAppName(\"Microbenchmark on GPU\")\n",
    "conf.set(\"spark.driver.memory\", driverMem)\n",
    "## The tasks will run on GPU memory, so there is no need to set a high host memory\n",
    "conf.set(\"spark.executor.memory\", executorMem)\n",
    "## The tasks will run on GPU cores, so there is no need to use many cpu cores\n",
    "conf.set(\"spark.executor.cores\", executorCores)\n",
    "conf.set(\"spark.locality.wait\", \"0\")\n",
    "conf.set(\"spark.sql.files.maxPartitionBytes\", maxPartionBytes) \n",
    "conf.set(\"spark.dynamicAllocation.enabled\", \"false\") \n",
    "conf.set(\"spark.sql.adaptive.enabled\", \"true\") \n",
    "\n",
    "# Plugin settings\n",
    "conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "# 4 tasks will run concurrently per GPU\n",
    "conf.set(\"spark.rapids.sql.concurrentGpuTasks\", concurrentGpuTasks)\n",
    "# Pinned 8g host memory to transfer data between GPU and host memory\n",
    "conf.set(\"spark.rapids.memory.pinnedPool.size\", pinnedPoolSize)\n",
    "# 16 tasks will run concurrently per executor, as we set spark.executor.cores=16\n",
    "conf.set(\"spark.task.resource.gpu.amount\", gpuPerExecutor) \n",
    "conf.set(\"spark.rapids.sql.enabled\", \"true\") \n",
    "conf.set(\"spark.plugins\", \"com.nvidia.spark.SQLPlugin\")\n",
    "conf.set(\"spark.rapids.sql.variableFloatAgg.enabled\", \"true\")\n",
    "conf.set(\"spark.driver.extraClassPath\", RAPIDS_JAR)\n",
    "conf.set(\"spark.executor.extraClassPath\", RAPIDS_JAR)\n",
    "conf.set(\"spark.jars\", RAPIDS_JAR)\n",
    "conf.set(\"spark.eventLog.enabled\", \"true\")\n",
    "conf.set(\"spark.eventLog.dir\", eventlogDir)\n",
    "# Create spark session\n",
    "spark = SparkSession.builder.config(conf=conf).getOrCreate()\n",
    "# Load dataframe and create tempView\n",
    "# You need to update data path to your real path!\n",
    "dataRoot = os.getenv(\"DATA_ROOT\", \"/data\")\n",
    "spark.read.parquet(dataRoot + \"/customer.dat\").createOrReplaceTempView(\"customer\")\n",
    "spark.read.parquet(dataRoot + \"/store_sales.dat\").createOrReplaceTempView(\"store_sales\")\n",
    "spark.read.parquet(dataRoot + \"/catalog_sales.dat\").createOrReplaceTempView(\"catalog_sales\")\n",
    "spark.read.parquet(dataRoot + \"/web_sales.dat\").createOrReplaceTempView(\"web_sales\")\n",
    "spark.read.parquet(dataRoot + \"/item.dat\").createOrReplaceTempView(\"item\")\n",
    "spark.read.parquet(dataRoot + \"/date_dim.dat\").createOrReplaceTempView(\"date_dim\")\n",
    "print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7136eb63",
   "metadata": {},
   "source": [
    "### Expand&HashAggregate\n",
    "This is a microbenchmark about Expand&HashAggregate expressions running on the GPU. The query calculates the distinct value of some dimension columns and average birth year by different c_salutation of customers after grouping by c_current_hdemo_sk. You will see about 10x speedups in this query. Because an additional shuffle involved by the repartition operator in CPU mode. And GPUExpand and GPUHashAggregate is much faster than Expand and HashAggregate because GPU algorithms allow us to parallelize the computation and we can utilize most of the GPU cores. The tasks' duration in the third stage is less than one second but will cost 20x-40x while running on CPU. There will be a more significant performance improvement along with the increasing number of count distinct columns and aggregate functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd12d749",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = '''\n",
    "select c_current_hdemo_sk,\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_salutation,null)) as c1,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_salutation,null)) as c12,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_salutation,null)) as c13,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_first_name,null)) as c2,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_first_name,null)) as c22,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_first_name,null)) as c23,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_last_name,null)) as c3,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_last_name,null)) as c32,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_last_name,null)) as c33,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_country,null)) as c4,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_country,null)) as c42,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_country,null)) as c43,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_email_address,null)) as c5,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_email_address,null)) as c52,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_email_address,null)) as c53,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_login,null)) as c6,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_login,null)) as c62,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_login,null)) as c63,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_preferred_cust_flag,null)) as c7,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_preferred_cust_flag,null)) as c72,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_preferred_cust_flag,null)) as c73,\n",
    "\n",
    "count(DISTINCT if(c_salutation==\"Ms.\",c_birth_month,null)) as c8,\n",
    "count(DISTINCT if(c_salutation==\"Mr.\",c_birth_month,null)) as c82,\n",
    "count(DISTINCT if(c_salutation==\"Dr.\",c_birth_month,null)) as c83,\n",
    "\n",
    "avg(if(c_salutation==\"Ms.\",c_birth_year,null)) as avg1,\n",
    "avg(if(c_salutation==\"Mr.\",c_birth_year,null)) as avg2,\n",
    "avg(if(c_salutation==\"Dr.\",c_birth_year,null)) as avg3,\n",
    "avg(if(c_salutation==\"Miss.\",c_birth_year,null)) as avg4,\n",
    "avg(if(c_salutation==\"Mrs.\",c_birth_year,null)) as avg5,\n",
    "avg(if(c_salutation==\"Sir.\",c_birth_year,null)) as avg6,\n",
    "avg(if(c_salutation==\"Professor.\",c_birth_year,null)) as avg7,\n",
    "avg(if(c_salutation==\"Teacher.\",c_birth_year,null)) as avg8,\n",
    "avg(if(c_salutation==\"Agent.\",c_birth_year,null)) as avg9,\n",
    "avg(if(c_salutation==\"Director.\",c_birth_year,null)) as avg10\n",
    "from customer group by c_current_hdemo_sk\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2e105bf8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83|              avg1|              avg2|              avg3|avg4|              avg5|avg6|avg7|avg8|avg9|avg10|\n",
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "|              1238|  1|  1|  1|284|255|562|358|467|772|194|203|211|452|664|1157|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\n",
      "|              6658|  1|  1|  1|318|253|541|384|492|752|190|203|210|516|647|1115|  0|  0|  0|  2|  2|  2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\n",
      "|              1088|  1|  1|  1|302|263|547|374|476|736|191|206|210|487|648|1074|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\n",
      "|              4818|  1|  1|  1|276|248|542|368|514|747|183|204|211|460|691|1093|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\n",
      "|               148|  1|  1|  1|309|260|562|392|501|772|187|207|211|488|668|1154|  0|  0|  0|  2|  2|  2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\n",
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 1, Expand&HashAggregate microbenchmark takes 11.13 seconds\n",
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "|c_current_hdemo_sk| c1|c12|c13| c2|c22|c23| c3|c32|c33| c4|c42|c43| c5|c52| c53| c6|c62|c63| c7|c72|c73| c8|c82|c83|              avg1|              avg2|              avg3|avg4|              avg5|avg6|avg7|avg8|avg9|avg10|\n",
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "|              1238|  1|  1|  1|284|255|562|358|467|772|194|203|211|452|664|1157|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.2444933920706|1958.8547655068078|1957.2870771899393|null| 1958.042643923241|null|null|null|null| null|\n",
      "|              6658|  1|  1|  1|318|253|541|384|492|752|190|203|210|516|647|1115|  0|  0|  0|  2|  2|  2| 12| 12| 12|1959.0155945419103|1958.9720930232559|1958.0089525514773|null|1959.2618025751074|null|null|null|null| null|\n",
      "|              4818|  1|  1|  1|276|248|542|368|514|747|183|204|211|460|691|1093|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.8954248366013|1958.1313131313132|1957.5018315018315|null|1958.0252293577983|null|null|null|null| null|\n",
      "|              1088|  1|  1|  1|302|263|547|374|476|736|191|206|210|487|648|1074|  0|  0|  0|  2|  2|  2| 12| 12| 12|1957.7084188911704|1959.1323076923077|1957.2780898876404|null|1958.5641025641025|null|null|null|null| null|\n",
      "|               148|  1|  1|  1|309|260|562|392|501|772|187|207|211|488|668|1154|  0|  0|  0|  2|  2|  2| 12| 12| 12| 1956.219008264463|1958.9161676646706|1957.8076256499132|null|1958.3412017167382|null|null|null|null| null|\n",
      "+------------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+---+---+---+---+---+---+---+---+---+------------------+------------------+------------------+----+------------------+----+----+----+----+-----+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 2, Expand&HashAggregate microbenchmark takes 7.74 seconds\n",
      "Expand&HashAggregate microbenchmark takes average 9 seconds after 2 retries\n"
     ]
    }
   ],
   "source": [
    "# Run microbenchmark with n retry time\n",
    "runMicroBenchmark(spark,\"Expand&HashAggregate\",query,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f50ec183",
   "metadata": {},
   "source": [
    "### Windowing(without data skew)\n",
    "This is a microbenchmark about windowing expressions running on GPU mode. The sub-query calculates the average ss_sales_price of a fixed window function partition by ss_customer_sk, and the parent query calculates the average price of the sub-query grouping by each customer. You will see about 25x speedups in this query. The speedup mainly comes from GPUSort/GPUWindow/GPUHashAggregate. The avg aggregation function evaluates all rows which are generated by the sub-query's window function. There will be a more significant performance improvement along with the increasing number of sub-query aggregate functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "31bd0635",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "query = '''\n",
    "select ss_customer_sk,avg(avg_price) as avg_price\n",
    "from\n",
    "(\n",
    "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n",
    "FROM store_sales\n",
    "where ss_customer_sk is not null\n",
    ") group by ss_customer_sk order by 2 desc \n",
    "'''\n",
    "print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f9e93983",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+------------------+\n",
      "|ss_customer_sk|         avg_price|\n",
      "+--------------+------------------+\n",
      "|      15924921|52.375180502283705|\n",
      "|      24796404| 52.21073975966333|\n",
      "|      14299506| 52.16263537127018|\n",
      "|      27571451|52.156112032252395|\n",
      "|      10174233| 52.06401030721082|\n",
      "+--------------+------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 1, Windowing without skew microbenchmark takes 11.39 seconds\n",
      "+--------------+-----------------+\n",
      "|ss_customer_sk|        avg_price|\n",
      "+--------------+-----------------+\n",
      "|      15924921|52.53781291335107|\n",
      "|      24796404|52.39683466140243|\n",
      "|      27571451|52.18830023174899|\n",
      "|      14299506|52.10829141087412|\n",
      "|      10174233|51.92766214818386|\n",
      "+--------------+-----------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 2, Windowing without skew microbenchmark takes 9.53 seconds\n",
      "Windowing without skew microbenchmark takes average 10 seconds after 2 retries\n"
     ]
    }
   ],
   "source": [
    "# Run microbenchmark with n retry time\n",
    "runMicroBenchmark(spark,\"Windowing without skew\",query,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcf08e47",
   "metadata": {},
   "source": [
    "### Windowing(with data skew)\n",
    "Data skew is caused by many null values in the ss_customer_sk column. You will see about 80x speedups in this query. The heavier skew task a query has, the more improved performance we will get because GPU parallelizes the computation, CPU is limited to just a single core because of how the algorithms are written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2b9d223c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "query = '''\n",
    "select ss_customer_sk,avg(avg_price) as avg_price\n",
    "from\n",
    "(\n",
    "SELECT ss_customer_sk ,avg(ss_sales_price) OVER (PARTITION BY ss_customer_sk order by ss_sold_date_sk ROWS BETWEEN 50 PRECEDING AND 50 FOLLOWING ) as avg_price\n",
    "FROM store_sales\n",
    ") group by ss_customer_sk order by 2 desc \n",
    "'''\n",
    "print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d7c65ee",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------+------------------+\n",
      "|ss_customer_sk|         avg_price|\n",
      "+--------------+------------------+\n",
      "|      24796404| 52.40675225109215|\n",
      "|      27571451|52.396675141359374|\n",
      "|      15924921| 52.30557497833058|\n",
      "|      10174233|52.088916933379096|\n",
      "|      14299506|51.995045713009794|\n",
      "+--------------+------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 1, Windowing with skew microbenchmark takes 17.46 seconds\n",
      "+--------------+------------------+\n",
      "|ss_customer_sk|         avg_price|\n",
      "+--------------+------------------+\n",
      "|      24796404|52.403564615099896|\n",
      "|      15924921|52.262694645994465|\n",
      "|      27571451| 52.14256448618127|\n",
      "|      10174233| 52.11346591610992|\n",
      "|      14299506| 51.99180221022445|\n",
      "+--------------+------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 2, Windowing with skew microbenchmark takes 16.63 seconds\n",
      "Windowing with skew microbenchmark takes average 17 seconds after 2 retries\n"
     ]
    }
   ],
   "source": [
    "# Run microbenchmark with n retry time\n",
    "runMicroBenchmark(spark,\"Windowing with skew\",query,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53c0ed28",
   "metadata": {},
   "source": [
    "### Intersection\n",
    "This is a microbenchmark about intersection operation running on GPU mode. The query calculates items in the same brand, class, and category that are sold in all three sales channels in two consecutive years. You will see about 10x speedups in this query. This is a competition between high cardinality SortMergeJoin vs GpuShuffleHashJoin. The mainly improved performance comes from two SortMergeJoin(s) in this query running on CPU get converted to GpuShuffleHashJoin running on GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "643c2e8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "query = '''\n",
    "select i_item_sk ss_item_sk\n",
    " from item,\n",
    "    (select iss.i_brand_id brand_id, iss.i_class_id class_id, iss.i_category_id category_id\n",
    "     from store_sales, item iss, date_dim d1\n",
    "     where ss_item_sk = iss.i_item_sk\n",
    "                    and ss_sold_date_sk = d1.d_date_sk\n",
    "       and d1.d_year between 1999 AND 1999 + 2\n",
    "   intersect\n",
    "     select ics.i_brand_id, ics.i_class_id, ics.i_category_id\n",
    "     from catalog_sales, item ics, date_dim d2\n",
    "     where cs_item_sk = ics.i_item_sk\n",
    "       and cs_sold_date_sk = d2.d_date_sk\n",
    "       and d2.d_year between 1999 AND 1999 + 2\n",
    "   intersect\n",
    "     select iws.i_brand_id, iws.i_class_id, iws.i_category_id\n",
    "     from web_sales, item iws, date_dim d3\n",
    "     where ws_item_sk = iws.i_item_sk\n",
    "       and ws_sold_date_sk = d3.d_date_sk\n",
    "       and d3.d_year between 1999 AND 1999 + 2) x\n",
    " where i_brand_id = brand_id\n",
    "   and i_class_id = class_id\n",
    "   and i_category_id = category_id\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "61bc2260",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+\n",
      "|ss_item_sk|\n",
      "+----------+\n",
      "|      4323|\n",
      "|      4324|\n",
      "|      4325|\n",
      "|      4327|\n",
      "|      4328|\n",
      "+----------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 1, NDS Q14a subquery microbenchmark takes 6.71 seconds\n",
      "+----------+\n",
      "|ss_item_sk|\n",
      "+----------+\n",
      "|     14103|\n",
      "|     14104|\n",
      "|     14105|\n",
      "|     14107|\n",
      "|     14108|\n",
      "+----------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Retry times : 2, NDS Q14a subquery microbenchmark takes 6.11 seconds\n",
      "NDS Q14a subquery microbenchmark takes average 6 seconds after 2 retries\n"
     ]
    }
   ],
   "source": [
    "# Run microbenchmark with n retry time\n",
    "runMicroBenchmark(spark,\"NDS Q14a subquery\",query,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1346d126",
   "metadata": {},
   "source": [
    "### Crossjoin\n",
    "This is a microbenchmark for a 1-million rows crossjoin with itself. You will see about 10x speedups in this query. The mainly improved performance comes from converting BroadcastNestedLoogJoin running on CPU to GpuBroadcastNestedLoogJoin running on GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "286ea45d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scanning and writing parquet cost : 5.31 seconds\n",
      "--------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "start = time() \n",
    "spark.read.parquet(dataRoot + \"/customer.dat\").limit(1000000).write.format(\"parquet\").mode(\"overwrite\").save(\"/data/tmp/customer1m\")\n",
    "end = time()\n",
    "# Parquet file scanning and writing will be about 3 times faster running on GPU\n",
    "print(\"scanning and writing parquet cost : {} seconds\".format(round(end - start, 2)))\n",
    "spark.read.parquet(\"/data/tmp/customer1m\").repartition(200).createOrReplaceTempView(\"costomer_df_1_million\")\n",
    "query = '''\n",
    "select count(*) from costomer_df_1_million c1 inner join costomer_df_1_million c2 on c1.c_customer_sk>c2.c_customer_sk\n",
    "'''\n",
    "print(\"-\"*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f41b8d54",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+------------+\n",
      "|    count(1)|\n",
      "+------------+\n",
      "|499999500000|\n",
      "+------------+\n",
      "\n",
      "Retry times : 1, Crossjoin microbenchmark takes 6.7 seconds\n",
      "+------------+\n",
      "|    count(1)|\n",
      "+------------+\n",
      "|499999500000|\n",
      "+------------+\n",
      "\n",
      "Retry times : 2, Crossjoin microbenchmark takes 6.37 seconds\n",
      "Crossjoin microbenchmark takes average 7 seconds after 2 retries\n"
     ]
    }
   ],
   "source": [
    "# Run microbenchmark with n retry time\n",
    "runMicroBenchmark(spark,\"Crossjoin\",query,2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06b351e6-b7bd-4063-a20b-fe4fd71221f9",
   "metadata": {},
   "source": [
    "### HashJoin\n",
    "This is a microbenchmark for a HashJoin. The query on GPU will be more than 10x times faster than CPU based on the cluster in the readme."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "191d0c9a-2d3a-40f4-89aa-f61dab5caa90",
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.read.parquet(dataRoot + \"/store_sales.dat\").createOrReplaceTempView(\"store_sales\")\n",
    "spark.read.parquet(dataRoot + \"/store_returns.dat\").createOrReplaceTempView(\"store_returns\")\n",
    "\n",
    "print(\"-\"*50)\n",
    "query = '''\n",
    "select  sum(store_sales.ss_ext_wholesale_cost)\n",
    "from store_sales\n",
    "join store_returns on (ss_item_sk = sr_item_sk) and (ss_addr_sk=sr_addr_sk)\n",
    "'''\n",
    "runMicroBenchmark(spark,\"HashJoin\",query,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc2092e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "spark.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
